/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package studio.raptor.ddal.core.engine.plan.node.impl.execute;

import java.sql.SQLException;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import studio.raptor.ddal.common.exception.ExecuteException;
import studio.raptor.ddal.common.exception.ExecuteException.Code;
import studio.raptor.ddal.common.util.StringUtil;
import studio.raptor.ddal.config.config.SystemProperties;
import studio.raptor.ddal.core.connection.BackendConnection;
import studio.raptor.ddal.core.connection.BackendDataSourceManager;
import studio.raptor.ddal.core.connection.ContextConnectionWrapper;
import studio.raptor.ddal.core.engine.ProcessContext;
import studio.raptor.ddal.core.engine.plan.node.ProcessNode;
import studio.raptor.ddal.core.executor.ExecutionUnit;
import studio.raptor.ddal.core.executor.strategy.ReadWriteStrategy;

/**
 * Data source selection node.
 *
 * <p>For every execution unit of the current execution group this node looks up (or creates)
 * the {@link ContextConnectionWrapper} registered under the unit's shard name, makes sure the
 * wrapper holds a read-only or a read-write backend connection as required, and marks that
 * connection as the wrapper's current connection.
 *
 * @author Sam
 * @since 3.0.0
 */
public class DataSourceSelector extends ProcessNode {

  private static final Logger logger = LoggerFactory.getLogger(DataSourceSelector.class);

  /**
   * Key looked up in the {@link SystemProperties} mapper to find the strategy class name.
   */
  private static final String READ_WRITE_STRATEGY_KEY = "strategy.readwrite";

  /**
   * Read/write control strategy. When none is configured (or the configured class is not a
   * {@link ReadWriteStrategy}), a default strategy whose {@code isReadOnly()} returns
   * {@code false} is used. Must be declared after {@link #logger} because the loader logs.
   */
  private static final ReadWriteStrategy readWriteStrategy = loadReadWriteStrategy();

  /**
   * Binds a backend connection to each execution unit's shard wrapper.
   *
   * <p>A read-only connection is chosen when the strategy reports read-only or the context
   * carries a read-only hint; otherwise a read-write connection is chosen. Connections already
   * present on the wrapper are reused.
   *
   * @param context the process context providing the execution group and the per-shard
   *     connection wrappers
   * @throws ExecuteException if a backend connection cannot be obtained
   */
  @Override
  protected void execute(ProcessContext context) {
    Map<String, ContextConnectionWrapper> connectionsWrapper = context.getShardBackendConnWrapper();
    for (ExecutionUnit unit : context.getCurrentExecutionGroup().getExecutionUnits()) {
      String shardName = unit.getShard().getName();
      ContextConnectionWrapper shardConnWrapper = connectionsWrapper.get(shardName);
      if (null == shardConnWrapper) {
        shardConnWrapper = new ContextConnectionWrapper();
        connectionsWrapper.put(shardName, shardConnWrapper);
      }

      // Evaluated per unit on purpose: the strategy decision is not assumed to be constant.
      BackendConnection selected;
      if (readWriteStrategy.isReadOnly() || context.hasReadonlyHint()) {
        selected = selectReadonlyConnection(shardConnWrapper, unit, context);
      } else {
        selected = selectReadWriteConnection(shardConnWrapper, unit, context);
      }
      shardConnWrapper.setCurrentConnection(selected);
    }
  }

  /**
   * Returns the wrapper's read-only connection, acquiring and storing one if it is absent.
   */
  private static BackendConnection selectReadonlyConnection(ContextConnectionWrapper wrapper,
      ExecutionUnit unit, ProcessContext context) {
    BackendConnection connection = wrapper.getReadonlyConnection();
    if (null == connection) {
      try {
        connection = BackendDataSourceManager
            .getBackendConnection(unit.getShard().getDsGroup(), true, context.isAutoCommit());
        wrapper.setReadonlyConnection(connection);
      } catch (SQLException e) {
        // Log the root cause here; the ExecuteException below carries only an error code.
        logger.error("Failed to get readonly connection for shard " + unit.getShard().getName(),
            e);
        throw ExecuteException.create(Code.GET_READONLY_CONNECTION_FAILED_ERROR);
      }
    }
    return connection;
  }

  /**
   * Returns the wrapper's read-write connection, acquiring and storing one if it is absent.
   */
  private static BackendConnection selectReadWriteConnection(ContextConnectionWrapper wrapper,
      ExecutionUnit unit, ProcessContext context) {
    BackendConnection connection = wrapper.getReadWriteConnection();
    if (null == connection) {
      try {
        connection = BackendDataSourceManager
            .getBackendConnection(unit.getShard().getDsGroup(), false, context.isAutoCommit());
        wrapper.setReadWriteConnection(connection);
      } catch (SQLException e) {
        // Log the root cause here; the ExecuteException below carries only an error code.
        logger.error("Failed to get readwrite connection for shard " + unit.getShard().getName(),
            e);
        throw ExecuteException.create(Code.GET_READWRITE_CONNECTION_FAILED_ERROR);
      }
    }
    return connection;
  }

  /**
   * Resolves the read/write strategy from the {@code strategy.readwrite} system property.
   *
   * @return an instance of the configured strategy class, or the default strategy when the
   *     property is empty or names a class that is not a {@link ReadWriteStrategy}
   * @throws ExecuteException if the configured class cannot be loaded or instantiated
   */
  private static ReadWriteStrategy loadReadWriteStrategy() {
    String strategyClazz = null;
    try {
      if (null != SystemProperties.getInstance() && null != SystemProperties.getInstance()
          .getMapper()) {
        strategyClazz = SystemProperties.getInstance().getMapper().get(READ_WRITE_STRATEGY_KEY);
      }

      if (StringUtil.isEmpty(strategyClazz)) {
        logger.info("No readwrite strategy configured, default readwrite strategy will be used");
        return defaultReadWriteStrategy();
      }

      Class<?> strategyClass = Class.forName(strategyClazz);
      if (!ReadWriteStrategy.class.isAssignableFrom(strategyClass)) {
        logger.info("Default readwrite strategy will be used instead of unknown strategy {}",
            strategyClazz);
        return defaultReadWriteStrategy();
      }

      try {
        // Class.newInstance() is deprecated and leaks checked constructor exceptions undeclared;
        // going through the no-arg constructor wraps them in InvocationTargetException instead.
        return (ReadWriteStrategy) strategyClass.getDeclaredConstructor().newInstance();
      } catch (ReflectiveOperationException e) {
        logger.error("Failed to instantiate readwrite strategy " + strategyClazz, e);
        throw ExecuteException.create(Code.READ_WRITE_STRATEGY_INSTANCE_ERROR);
      }
    } catch (ClassNotFoundException | ExceptionInInitializerError ex) {
      logger.error("Failed to load readwrite strategy " + strategyClazz, ex);
      throw ExecuteException.create(Code.READ_WRITE_STRATEGY_CONFIG_ERROR);
    }
  }

  /**
   * Creates the fallback strategy, which never forces read-only routing.
   */
  private static ReadWriteStrategy defaultReadWriteStrategy() {
    return new ReadWriteStrategy() {
      @Override
      public boolean isReadOnly() {
        return false;
      }
    };
  }
}
